#!/usr/bin/env python
# https://github.com/dranaju/project

import rospy
import os
import json
import numpy as np
import random
import time
import sys
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
from collections import deque
from std_msgs.msg import Float32
from src.turtlebot3_dqn.environment_TD3 import Env
import torch
import torch.nn.functional as F
import gc
import torch.nn as nn
import math

#---Directory Path---#
# Directory that contains this script (os.path.realpath resolves symlinks
# first). Trainer.load_models builds its checkpoint paths from this value.
dirPath = os.path.dirname(os.path.realpath(__file__))

#---Functions to make network updates---#
def soft_update(target, source, tau):
    """Move the parameters of `target` a fraction `tau` towards `source`.

    Every target parameter is overwritten in place with
    (1 - tau) * target + tau * source. Parameters are paired positionally
    through zip(), so both modules are expected to expose their parameters
    in the same order and with matching shapes.

    Args:
        target: Module whose parameters are updated in place.
        source: Module whose parameters are read.
        tau: Interpolation factor; 0 keeps the target, 1 copies the source.
    """
    pairs = zip(target.parameters(), source.parameters())
    for dst, src in pairs:
        blended = dst.data * (1.0 - tau) + src.data * tau
        dst.data.copy_(blended)

def hard_update(target, source):
    """Overwrite every parameter of `target` with the matching one of `source`.

    The copy is done in place on the target tensors, so the two modules keep
    separate storage afterwards. Parameters are paired positionally.

    Args:
        target: Module whose parameters are overwritten.
        source: Module whose parameters are copied.
    """
    pairs = zip(target.parameters(), source.parameters())
    for dst, src in pairs:
        dst.data.copy_(src.data)

#---Ornstein-Uhlenbeck Noise for action---#

class ActionNoise:
    """Temporally correlated noise process for action exploration.

    Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab

    The internal state X starts at mu and, on every sample(), is pulled back
    towards mu with strength theta and perturbed by Gaussian noise scaled by
    sigma.
    """

    def __init__(self, action_dim, mu=0, theta=0.15, sigma=0.2):
        """Store the process parameters and start the state at mu.

        Args:
            action_dim: Number of independent noise components.
            mu: Value the process reverts to.
            theta: Mean-reversion strength.
            sigma: Scale of the Gaussian perturbation.
        """
        self.action_dim = action_dim
        self.mu = mu
        self.theta = theta
        self.sigma = sigma
        self.X = self.mu * np.ones(self.action_dim)

    def reset(self):
        """Restart the process at mu."""
        self.X = self.mu * np.ones(self.action_dim)

    def sample(self):
        """Advance the process by one step and return the new state array."""
        drift = self.theta * (self.mu - self.X)
        diffusion = self.sigma * np.random.randn(len(self.X))
        self.X = self.X + (drift + diffusion)
        return self.X

#---Critic--#

# Half-width of the uniform range used for the weights of the output layers.
EPS = 0.003


def fanin_init(size, fanin=None):
    """Return a new tensor of shape `size` drawn uniformly from [-v, v].

    v is 1 / sqrt(fanin). When `fanin` is omitted (or falsy) the first entry
    of `size` is used in its place.

    NOTE(review): the callers pass the shape of an nn.Linear weight, which
    PyTorch lays out as (out_features, in_features), so size[0] would be the
    number of outputs rather than inputs -- confirm this is intended.

    Args:
        size: Shape of the tensor to create (callers pass a torch.Size).
        fanin: Optional override for the value used in the bound.
    """
    if not fanin:
        fanin = size[0]
    bound = 1. / np.sqrt(fanin)
    return torch.Tensor(size).uniform_(-bound, bound)

class Critic(nn.Module):
    """First Q-network: scores a (state, action) pair with a single value.

    The state and the action are embedded separately into 256 features each,
    concatenated, passed through one 512-unit hidden layer and reduced to one
    output.
    """

    @staticmethod
    def _fanin_linear(n_in, n_out):
        """Create a linear layer whose weight is re-drawn with fanin_init."""
        layer = nn.Linear(n_in, n_out)
        layer.weight.data = fanin_init(layer.weight.data.size())
        return layer

    def __init__(self, state_dim, action_dim):
        """Build the layers.

        Args:
            state_dim: Length of the state vector.
            action_dim: Length of the action vector.
        """
        super(Critic, self).__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim

        # Separate embeddings for the state and the action.
        self.fc1 = self._fanin_linear(state_dim, 256)
        self.fa1 = self._fanin_linear(action_dim, 256)

        # Joint hidden layer over the concatenated embeddings.
        self.fca1 = self._fanin_linear(512, 512)

        # Output layer starts with small weights in [-EPS, EPS].
        self.fca2 = nn.Linear(512, 1)
        self.fca2.weight.data.uniform_(-EPS, EPS)

    def forward(self, state, action):
        """Return the value estimate for each (state, action) row.

        Both inputs are concatenated along dim 1, so they must be batched
        2-D tensors with the same number of rows.
        """
        state_features = torch.relu(self.fc1(state))
        action_features = torch.relu(self.fa1(action))
        joint = torch.cat((state_features, action_features), dim=1)
        hidden = torch.relu(self.fca1(joint))
        return self.fca2(hidden)

# second critic
class Critic_2(nn.Module):
    """Second Q-network, with the same layout as Critic but its own weights.

    The state and the action are embedded separately into 256 features each,
    concatenated, passed through one 512-unit hidden layer and reduced to one
    output. Layer names carry a `_2` suffix; they are part of the saved
    state_dict keys and must not be renamed.
    """

    def __init__(self, state_dim, action_dim):
        """Build the layers.

        Args:
            state_dim: Length of the state vector.
            action_dim: Length of the action vector.
        """
        super(Critic_2, self).__init__()
        # Marcel: Q2 Architecture
        self.state_dim = state_dim
        self.action_dim = action_dim

        self.fc1_2 = nn.Linear(state_dim, 256)
        self.fc1_2.weight.data = fanin_init(self.fc1_2.weight.data.size())

        self.fa1_2 = nn.Linear(action_dim, 256)
        self.fa1_2.weight.data = fanin_init(self.fa1_2.weight.data.size())

        self.fca1_2 = nn.Linear(512, 512)
        self.fca1_2.weight.data = fanin_init(self.fca1_2.weight.data.size())

        # Output layer starts with small weights in [-EPS, EPS].
        self.fca2_2 = nn.Linear(512, 1)
        self.fca2_2.weight.data.uniform_(-EPS, EPS)

    def forward_2(self, state, action):
        """Return the value estimate for each (state, action) row.

        Both inputs are concatenated along dim 1, so they must be batched
        2-D tensors with the same number of rows.
        """
        xs_2 = torch.relu(self.fc1_2(state))
        xa_2 = torch.relu(self.fa1_2(action))
        x_2 = torch.cat((xs_2, xa_2), dim=1)
        x_2 = torch.relu(self.fca1_2(x_2))
        return self.fca2_2(x_2)

    def forward(self, state, action):
        """Standard nn.Module entry point; delegates to forward_2.

        Without it, calling the module directly (critic_2(state, action))
        raises NotImplementedError because nn.Module dispatches to forward.
        """
        return self.forward_2(state, action)

#---Actor---#

class Actor(nn.Module):
    """Policy network mapping a state to a bounded two-component action.

    The first output component is squashed with a sigmoid and scaled by
    action_limit_v, so it lies in [0, action_limit_v]. The second is squashed
    with tanh and scaled by action_limit_w, so it lies in
    [-action_limit_w, action_limit_w]. Any further components (action_dim > 2)
    are returned unsquashed.
    """

    def __init__(self, state_dim, action_dim, action_limit_v, action_limit_w):
        """Build the layers.

        Args:
            state_dim: Length of the state vector.
            action_dim: Number of action components produced.
            action_limit_v: Upper bound of the first action component.
            action_limit_w: Magnitude bound of the second action component.
        """
        super(Actor, self).__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.action_limit_v = action_limit_v
        self.action_limit_w = action_limit_w

        self.fa1 = nn.Linear(state_dim, 512)
        self.fa1.weight.data = fanin_init(self.fa1.weight.data.size())

        self.fa2 = nn.Linear(512, 512)
        self.fa2.weight.data = fanin_init(self.fa2.weight.data.size())

        # Output layer starts with small weights in [-EPS, EPS].
        self.fa3 = nn.Linear(512, action_dim)
        self.fa3.weight.data.uniform_(-EPS, EPS)

    def forward(self, state):
        """Compute the action for one state or for a batch of states.

        Args:
            state: Tensor whose last axis has length state_dim; either a
                single 1-D state or a 2-D batch of states.

        Returns:
            Tensor with the same leading shape as `state` and action_dim
            entries on the last axis.
        """
        x = torch.relu(self.fa1(state))
        x = torch.relu(self.fa2(x))
        raw = self.fa3(x)

        # Index the last axis so batched and unbatched inputs share one code
        # path for any state_dim (previously only a 1-D state of length 364
        # was recognised as unbatched). Building a new tensor also avoids
        # writing in place into the output of fa3.
        v = torch.sigmoid(raw[..., 0:1]) * self.action_limit_v
        w = torch.tanh(raw[..., 1:2]) * self.action_limit_w
        parts = [v, w]
        if raw.shape[-1] > 2:
            parts.append(raw[..., 2:])
        return torch.cat(parts, dim=-1)

#---Memory Buffer---#

class MemoryBuffer:
    """Fixed-capacity replay buffer of (s, a, r, new_s) transitions.

    Transitions live in a deque created with maxlen=size, so once the buffer
    is full each new transition discards the oldest one.

    Attributes:
        buffer: The underlying deque of transitions.
        maxSize: Capacity passed at construction.
        len: Number of transitions currently stored. Callers read this
            attribute directly; len(buffer_instance) returns the same value.
    """

    def __init__(self, size):
        self.buffer = deque(maxlen=size)
        self.maxSize = size
        self.len = 0

    def sample(self, count):
        """Draw up to `count` distinct transitions uniformly at random.

        Args:
            count: Requested batch size; capped at the number stored.

        Returns:
            Four float32 arrays (states, actions, rewards, next states), each
            with one entry per sampled transition.
        """
        count = min(count, self.len)
        batch = random.sample(self.buffer, count)

        s_array = np.float32([transition[0] for transition in batch])
        a_array = np.float32([transition[1] for transition in batch])
        r_array = np.float32([transition[2] for transition in batch])
        new_s_array = np.float32([transition[3] for transition in batch])

        return s_array, a_array, r_array, new_s_array

    def __len__(self):
        # Replaces the former `len` method, which could never be called on an
        # instance because the `len` attribute set in __init__ shadowed it.
        return self.len

    def add(self, s, a, r, new_s):
        """Append one transition, evicting the oldest when at capacity."""
        self.buffer.append((s, a, r, new_s))
        # The deque enforces the capacity, so its length is the true count.
        self.len = len(self.buffer)

#---Training: hyperparameters and the Trainer class---#

# Number of transitions requested from the replay buffer per optimizer() call.
BATCH_SIZE = 128
# Learning rate passed to every Adam optimizer created by Trainer.
LEARNING_RATE = 0.001
# Discount factor multiplying the next-state value in the critic target.
GAMMA = 0.99
# Interpolation factor passed to soft_update for the target networks.
TAU = 0.003
# TAU = 0.001

class Trainer:
    """Actor / twin-critic trainer with delayed policy updates.

    Owns the actor, two critics, a target copy of each network and one Adam
    optimizer per trained network. Every optimizer() call trains both critics
    on one replay batch; every `frequency`-th call additionally trains the
    actor and soft-updates all three target networks.
    """

    def __init__(self, state_dim, action_dim, action_limit_v, action_limit_w, ram):
        """Create networks, optimizers and the Q-value publisher.

        Args:
            state_dim: Length of the state vector.
            action_dim: Number of action components.
            action_limit_v: Forwarded to Actor as the first action bound.
            action_limit_w: Forwarded to Actor as the second action bound.
            ram: Replay buffer providing sample(count).
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.action_limit_v = action_limit_v
        self.action_limit_w = action_limit_w
        self.ram = ram
        self.noise = ActionNoise(self.action_dim)

        self.actor = Actor(self.state_dim, self.action_dim, self.action_limit_v, self.action_limit_w)
        self.target_actor = Actor(self.state_dim, self.action_dim, self.action_limit_v, self.action_limit_w)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), LEARNING_RATE)

        self.critic = Critic(self.state_dim, self.action_dim)
        self.target_critic = Critic(self.state_dim, self.action_dim)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), LEARNING_RATE)

        # Second critic.
        self.critic_2 = Critic_2(self.state_dim, self.action_dim)
        self.target_critic_2 = Critic_2(self.state_dim, self.action_dim)
        self.critic_optimizer_2 = torch.optim.Adam(self.critic_2.parameters(), LEARNING_RATE)

        self.pub_qvalue = rospy.Publisher('qvalue', Float32, queue_size=5)
        self.qvalue = Float32()

        # The actor and the targets are updated every `frequency`-th call of
        # optimizer(); `delay` counts those calls.
        self.frequency = 5
        self.delay = 0

        # Checkpoints are written next to the package's nodes directory.
        self.dirPath = os.path.dirname(os.path.realpath(__file__))
        self.dirPath = self.dirPath.replace('turtlebot3_dqn/nodes', 'turtlebot3_dqn/save_model_TD3/')

        # Start every target network as an exact copy of its online network.
        # The second critic's target was previously left at its own random
        # initialisation.
        hard_update(self.target_actor, self.actor)
        hard_update(self.target_critic, self.critic)
        hard_update(self.target_critic_2, self.critic_2)

    def get_exploitation_action(self, state):
        """Return the target actor's action for `state` as a numpy array.

        Args:
            state: Numpy array accepted by torch.from_numpy.
        """
        state = torch.from_numpy(state)
        action = self.target_actor.forward(state).detach()
        return action.data.numpy()

    def get_exploration_action(self, state):
        """Return the online actor's action for `state` as a numpy array.

        No noise is added here: the noise this method used to compute was
        never applied to the returned action, so that dead computation has
        been removed. The training loop in this file perturbs the returned
        action itself.

        Args:
            state: Numpy array accepted by torch.from_numpy.
        """
        state = torch.from_numpy(state)
        action = self.actor.forward(state).detach()
        return action.data.numpy()

    def optimizer(self):
        """Run one training step on a batch sampled from the replay buffer.

        Both critics are regressed (smooth L1) onto
        r + GAMMA * min(Q1_target, Q2_target) evaluated at the target actor's
        action for the next state. The buffer stores no terminal flag, so the
        bootstrap term is always included. The largest prediction of the
        first critic in the batch is published on the 'qvalue' topic.
        """
        # Counts calls so the policy update can be delayed.
        self.delay += 1

        # Sample from the buffer handed to the constructor rather than from a
        # module-level global.
        s_sample, a_sample, r_sample, new_s_sample = self.ram.sample(BATCH_SIZE)

        s_sample = torch.from_numpy(s_sample)
        a_sample = torch.from_numpy(a_sample)
        r_sample = torch.from_numpy(r_sample)
        new_s_sample = torch.from_numpy(new_s_sample)

        #-------------- critic target
        a_target = self.target_actor.forward(new_s_sample).detach()
        next_value_1 = torch.squeeze(self.target_critic.forward(new_s_sample, a_target).detach())
        next_value_2 = torch.squeeze(self.target_critic_2.forward_2(new_s_sample, a_target).detach())
        # Use the smaller of the two target estimates.
        next_value = torch.min(next_value_1, next_value_2)
        y_expected = r_sample + GAMMA*next_value

        # Predictions of both online critics.
        y_predicted = torch.squeeze(self.critic.forward(s_sample, a_sample))
        y_predicted_2 = torch.squeeze(self.critic_2.forward_2(s_sample, a_sample))

        #-------Publisher of Vs------
        self.qvalue = y_predicted.detach()
        # Convert to a plain float: the topic carries a Float32 message.
        self.pub_qvalue.publish(float(torch.max(self.qvalue)))

        #----------------------------Optimize Critic
        loss_critic = F.smooth_l1_loss(y_predicted, y_expected)
        loss_critic_2 = F.smooth_l1_loss(y_predicted_2, y_expected)

        self.critic_optimizer.zero_grad()
        self.critic_optimizer_2.zero_grad()

        loss_critic.backward()
        loss_critic_2.backward()

        self.critic_optimizer.step()
        self.critic_optimizer_2.step()

        # Delayed actor and target updates.
        if self.delay % self.frequency == 0:
            #------------ optimize actor
            pred_a_sample = self.actor.forward(s_sample)
            # Maximise the first critic's value of the actor's own actions.
            loss_actor = -1*torch.sum(self.critic.forward(s_sample, pred_a_sample))

            self.actor_optimizer.zero_grad()
            loss_actor.backward()
            self.actor_optimizer.step()

            soft_update(self.target_actor, self.actor, TAU)
            soft_update(self.target_critic, self.critic, TAU)
            soft_update(self.target_critic_2, self.critic_2, TAU)

    def save_models(self, episode_count):
        """Save the three target networks, prefixed with `episode_count`.

        Files are written to self.dirPath as '<n>_actor.pt', '<n>_critic.pt'
        and '<n>_critic_2.pt'.
        """
        torch.save(self.target_actor.state_dict(), self.dirPath + str(episode_count)+ '_actor.pt')
        torch.save(self.target_critic.state_dict(), self.dirPath + str(episode_count)+ '_critic.pt')
        torch.save(self.target_critic_2.state_dict(), self.dirPath +str(episode_count)+ '_critic_2.pt')

    def load_models(self, episode):
        """Load the online networks for `episode` and re-sync the targets.

        NOTE(review): this reads from '<script dir>/Models/real1/', which is
        not the directory save_models writes to -- confirm that checkpoints
        are expected to be copied there by hand.
        """
        self.actor.load_state_dict(torch.load(dirPath + '/Models/real1/'+str(episode)+ '_actor.pt'))
        self.critic.load_state_dict(torch.load(dirPath + '/Models/real1/'+str(episode)+ '_critic.pt'))
        self.critic_2.load_state_dict(torch.load(dirPath + '/Models/real1/'+str(episode)+ '_critic_2.pt'))

        hard_update(self.target_actor, self.actor)
        hard_update(self.target_critic, self.critic)
        hard_update(self.target_critic_2, self.critic_2)

#---Run agent---#

# Switches between training (noisy actions, learning enabled) and pure
# evaluation of the target actor.
is_training = True

if is_training:
    exploration_rate = 1
    max_exploration_rate = 1
    min_exploration_rate = 0.05
else:
    exploration_rate = 0.05
    max_exploration_rate = 0.05
    min_exploration_rate = 0.05

exploration_decay_rate = 0.001

MAX_EPISODES = 10001
MAX_STEPS = 2000
MAX_BUFFER = 1000000
rewards_all_episodes = []

STATE_DIMENSION = 364
ACTION_DIMENSION = 2
ACTION_V_MAX = 0.15 # m/s
ACTION_W_MAX = 1.5 # rad/s

# Standard deviations of the Gaussian noise added to each action component.
if is_training:
    var_v = ACTION_V_MAX*0.20
    var_w = ACTION_W_MAX*2*0.20
else:
    var_v = ACTION_V_MAX*0.10
    var_w = ACTION_W_MAX*0.10

ram = MemoryBuffer(MAX_BUFFER)
trainer = Trainer(STATE_DIMENSION, ACTION_DIMENSION, ACTION_V_MAX, ACTION_W_MAX, ram)
if __name__ == '__main__':
    start_time = time.time()
    rospy.init_node('TD3_1')
    pub_result = rospy.Publisher('result', Float32, queue_size=5)
    result = Float32()
    env = Env()
    with open('/home/ifl/catkin_ws/src/turtlebot3_machine_learning/turtlebot3_dqn/results_experiments/TD3_world_dynamic_1_bis.csv', 'w+') as saveFile:

        saveFile.write("Q-value,Crashes,Goals,Time" + '\n')
        for ep in range(MAX_EPISODES):
            done = False
            state = env.reset()
            print('Episode: ' + str(ep))
            score = 0
            rewards_current_episode = 0
            for step in range(MAX_STEPS):
                state = np.float32(state)

                if is_training:
                    # Perturb the policy action with Gaussian noise and clip
                    # it back into the valid velocity ranges.
                    action = trainer.get_exploration_action(state)
                    action[0] = np.clip(np.random.normal(action[0], var_v), 0., ACTION_V_MAX)
                    action[1] = np.clip(np.random.normal(action[1], var_w), -ACTION_W_MAX, ACTION_W_MAX)
                else:
                    action = trainer.get_exploitation_action(state)

                next_state, reward, done, crashes, goals = env.step(action)
                score += reward
                rewards_current_episode += reward
                next_state = np.float32(next_state)
                ram.add(state, action, reward, next_state)
                state = next_state

                # Learning starts once MAX_STEPS transitions are stored; the
                # noise scale then decays towards 10% of each action limit.
                if ram.len >= MAX_STEPS and is_training:
                    var_v = max([var_v*0.99999, 0.10*ACTION_V_MAX])
                    var_w = max([var_w*0.99999, 0.10*ACTION_W_MAX])
                    trainer.optimizer()

                if done:
                    m, s = divmod(int(time.time() - start_time), 60)
                    h, m = divmod(m, 60)
                    # Column order follows the header (score, crashes, goals,
                    # time); goals and crashes used to be written swapped.
                    saveFile.write(str(score) + "," + str(crashes) + "," + str(goals) + "," + str(h) + "." + str(m) + "." + str(s) + '\n')
                    # Single-argument print: same text as the former Python 2
                    # print statement, and also valid on Python 3.
                    print(' '.join(str(item) for item in (
                        "score: ", score, "crashes: ", crashes,
                        "goals: ", goals, "time: ", h, m, s)))
                    saveFile.flush()
                    rewards_all_episodes.append(rewards_current_episode)
                    result = rewards_current_episode
                    pub_result.publish(result)
                    break
            # Decayed per episode; not read anywhere in this file.
            exploration_rate = (min_exploration_rate +
                    (max_exploration_rate - min_exploration_rate)* np.exp(-exploration_decay_rate*ep))
            gc.collect()
            # Checkpoint every 100 episodes (including episode 0).
            if ep%100 == 0:
                trainer.save_models(ep)
